import os
import wave
import unittest

import pytest

from io import BytesIO

from pyglet.media.synthesis import *


# Absolute path of the shared test data directory, two levels above the
# directory containing this file. Computed in one expression so no helper
# name is left behind in the module namespace.
test_data_path = os.path.abspath(
    os.path.join(os.path.dirname(__file__), '..', '..', 'data')
)


def get_test_data_file(*file_parts):
    """Return the path of a file inside the shared test data directory.

    The arguments are path components relative to that directory, given
    exactly as they would be to ``os.path.join()``. The platform's own
    separator is used, so callers never hard-code ``/`` or ``\\``.
    """
    components = (test_data_path,) + file_parts
    return os.path.join(*components)


# Every synthesis source class exercised by the parametrized tests below.
classes = [
    Silence,
    WhiteNoise,
    Sine,
    Triangle,
    Sawtooth,
    Square,
]
# Sample rates to exercise. test_generated_bytes expects a reference file
# named synthesis_<class>_16_<rate>_1ch.wav under data/media for each rate.
# NOTE(review): 44800 is unusual (44100/48000 are the common rates);
# presumably it matches the existing fixture files — confirm before changing.
sample_rates = [44800, 11025]


@pytest.mark.parametrize('source_class', classes)
@pytest.mark.parametrize('sample_rate', sample_rates)
def test_instantiation(source_class, sample_rate):
    """Each synthesis source can be constructed at each tested sample rate."""
    source = source_class(duration=1, sample_rate=sample_rate)
    # Go beyond "does not raise": the requested rate must be reflected in the
    # audio format the source advertises to consumers.
    assert source.audio_format.sample_rate == sample_rate


@pytest.mark.parametrize('source_class', classes)
@pytest.mark.parametrize('sample_rate', sample_rates)
def test_total_duration(source_class, sample_rate):
    """A one-second source yields exactly one second of audio, then None.

    Slightly more than one second's worth of bytes is requested. The source
    must return only what it has and then report exhaustion.
    """
    # Pass sample_rate through so each parametrization actually exercises
    # its own rate instead of the constructor's default.
    source = source_class(duration=1, sample_rate=sample_rate)   # One second of audio...
    expected_bytes = source.audio_format.bytes_per_second           # should match this.
    # Over-request, so any overrun past the one-second mark would be visible.
    audio_data = source.get_audio_data(expected_bytes + 100)

    assert audio_data.length == pytest.approx(expected_bytes)
    assert audio_data.duration == pytest.approx(1.0)
    assert len(audio_data.data) == pytest.approx(expected_bytes)

    # Should now be out of data
    last_data = source.get_audio_data(100)
    assert last_data is None


@pytest.mark.parametrize('source_class', classes)
@pytest.mark.parametrize('sample_rate', sample_rates)
def test_generated_bytes(source_class, sample_rate):
    """Generated PCM data matches a stored reference WAV for the same settings.

    The reference file is
    ``media/synthesis_<lowercased class name>_16_<sample_rate>_1ch.wav``
    inside the test data directory.
    """
    if source_class is WhiteNoise:
        # Noise output is random, so no fixed reference can match it. Skip
        # (rather than return) so the report doesn't show a false pass.
        pytest.skip("WhiteNoise output is random and has no reference file")

    source = source_class(duration=1, sample_rate=sample_rate)
    source_name = source_class.__name__.lower()
    filename = "synthesis_{0}_{1}_{2}_1ch.wav".format(source_name, 16, sample_rate)

    with wave.open(get_test_data_file('media', filename), 'rb') as f:
        loaded_bytes = f.readframes(-1)

    # NOTE(review): _max_offset is a private attribute of the source,
    # presumably its total length in bytes; used here to request everything.
    generated_data = source.get_audio_data(source._max_offset)
    generated_bytes = bytes(generated_data.data)
    # Compare a small chunk, to avoid hanging on mismatch:
    assert generated_bytes[:1000] == loaded_bytes[:1000], "Generated bytes do not match sample wave file."
